{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Vector-space models: Static representations from contextual models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "__author__ = \"Christopher Potts\"\n",
    "__version__ = \"CS224u, Stanford, Spring 2022\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Contents\n",
    "\n",
    "1. [Overview](#Overview)\n",
    "1. [General set-up](#General-set-up)\n",
    "1. [Loading Transformer models](#Loading-Transformer-models)\n",
    "1. [The basics of tokenizing](#The-basics-of-tokenizing)\n",
    "1. [The basics of representations](#The-basics-of-representations)\n",
    "1. [The decontextualized approach](#The-decontextualized-approach)\n",
    "    1. [Basic example](#Basic-example)\n",
    "    1. [Creating a full VSM](#Creating-a-full-VSM)\n",
    "1. [The aggregated approach](#The-aggregated-approach)\n",
    "1. [Some related work](#Some-related-work)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Can we get good static representations of words from models (like BERT) that supply only contextual representations? On the one hand, contextual models are very successful across a wide range of tasks, in large part because they are trained for a long time on a lot of data. This should be a boon for VSMs as we've designed them so far. On the other hand, the goal of having static representations might seem to be at odds with how these models process and represent examples. Part of the point is to obtain different representations for words depending on the context in which they occur, and a hallmark of the training procedure is that it processes sequences rather than individual words.\n",
    "\n",
    "[Bommasani et al. (2020)](https://www.aclweb.org/anthology/2020.acl-main.431) make a significant step forward in our understanding of these issues. Ultimately, they arrive at a positive answer: excellent static word representations can be obtained from contextual models. They explore two strategies for achieving this:\n",
    "\n",
    "1. __The decontextualized approach__: just process individual words as though they were isolated texts. Where a word consists of multiple tokens in the model, pool them with a function like mean or max.\n",
    "1. __The aggregated approach__: process lots and lots of texts containing the words of interest. As before, pool the sub-word tokens within each occurrence, and then pool across all of those occurrence-level representations.\n",
    "\n",
    "As Bommasani et al. say, the decontextualized approach \"presents an unnatural input\" – these models were not trained on individual words, but rather on longer sequences, so the individual words are infrequent kinds of inputs at best (and unattested as far as the model is concerned if the special boundary tokens [CLS] and [SEP] are not included). However, in practice, Bommasani et al. achieve very impressive results with this approach on word similarity/relatedness tasks.\n",
    "\n",
    "The aggregated approach yields even better results, but it requires more work and involves more decisions relating to which texts are processed.\n",
    "\n",
    "This notebook briefly explores both of these approaches, with the goal of making it easy for you to apply these methods in [the associated homework and bakeoff](hw_wordrelatedness.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## General set-up\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import torch\n",
    "from transformers import BertModel, BertTokenizer\n",
    "from transformers import RobertaModel, RobertaTokenizer\n",
    "\n",
    "import utils\n",
    "import vsm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Make sure you've downloaded [the data distribution for this course](http://web.stanford.edu/class/cs224u/data/data.tgz), unpacked it, and placed it in the current directory (or wherever you point `DATA_HOME` to below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_HOME = os.path.join('data', 'vsmdata')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set all the random seeds for reproducibility:\n",
    "\n",
    "utils.fix_random_seeds()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `transformers` library does a lot of logging. To avoid ending up with a cluttered notebook, I am changing the logging level. You might want to skip this as you scale up to building production systems, since the logging is very good – it gives you a lot of insights into what the models and code are doing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "logger = logging.getLogger()\n",
    "logger.setLevel(logging.ERROR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading Transformer models\n",
    "\n",
    "To start, let's get a feel for the basic API that `transformers` provides. The first step is specifying the pretrained parameters we'll be using:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_weights_name = 'bert-base-cased'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are lots of other options for pretrained weights. See [this Hugging Face directory](https://huggingface.co/models)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we specify a tokenizer and a model that match both each other and our choice of pretrained weights:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_tokenizer = BertTokenizer.from_pretrained(bert_weights_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "bert_model = BertModel.from_pretrained(bert_weights_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The basics of tokenizing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It's illuminating to see what the tokenizer does to example texts:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "example_text = \"Bert knows Snuffleupagus\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Simple tokenization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Bert', 'knows', 'S', '##nu', '##ffle', '##up', '##agu', '##s']"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bert_tokenizer.tokenize(example_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `encode` method maps individual strings to indices into the underlying embedding used by the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[101, 15035, 3520, 156, 14787, 13327, 4455, 28026, 1116, 102]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ex_ids = bert_tokenizer.encode(example_text, add_special_tokens=True)\n",
    "\n",
    "ex_ids"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can get a better feel for what these representations are like by mapping the indices back to \"words\":"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['[CLS]',\n",
       " 'Bert',\n",
       " 'knows',\n",
       " 'S',\n",
       " '##nu',\n",
       " '##ffle',\n",
       " '##up',\n",
       " '##agu',\n",
       " '##s',\n",
       " '[SEP]']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bert_tokenizer.convert_ids_to_tokens(ex_ids)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Those are all the essential ingredients for working with these parameters in Hugging Face. Of course, the library has a lot of other functionality, but the above suffices for our current application."
   ]
  },
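  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition about where the `##` pieces above come from, here is a simplified sketch of greedy longest-match-first segmentation, the general strategy behind WordPiece tokenizers. The function `wordpiece_sketch` and the toy vocabulary are hypothetical illustrations; the real `BertTokenizer` also handles normalization, punctuation splitting, and a much larger vocabulary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def wordpiece_sketch(word, vocab, unk_token='[UNK]'):\n",
    "    # Greedy longest-match-first segmentation of a single word.\n",
    "    pieces = []\n",
    "    start = 0\n",
    "    while start < len(word):\n",
    "        end = len(word)\n",
    "        piece = None\n",
    "        # Try the longest remaining substring first, then shrink it.\n",
    "        while start < end:\n",
    "            candidate = word[start: end]\n",
    "            if start > 0:\n",
    "                candidate = '##' + candidate  # Mark word-internal pieces.\n",
    "            if candidate in vocab:\n",
    "                piece = candidate\n",
    "                break\n",
    "            end -= 1\n",
    "        if piece is None:\n",
    "            return [unk_token]  # No segmentation is possible with this vocab.\n",
    "        pieces.append(piece)\n",
    "        start = end\n",
    "    return pieces\n",
    "\n",
    "# Toy vocabulary (an assumption for illustration, not BERT's actual vocabulary):\n",
    "toy_vocab = {'S', '##nu', '##ffle', '##up', '##agu', '##s', 'knows'}\n",
    "\n",
    "toy_pieces = wordpiece_sketch('Snuffleupagus', toy_vocab)\n",
    "\n",
    "toy_pieces"
   ]
  },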
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The basics of representations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To obtain the representations for a batch of examples, we use the `forward` method of the model, as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    reps = bert_model(torch.tensor([ex_ids]), output_hidden_states=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The return value `reps` is a special `transformers` class that holds a lot of representations. If we want just the final output representations for each token, we use `last_hidden_state`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 10, 768])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reps.last_hidden_state.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shape indicates that our batch has 1 example, with 10 tokens, and each token is represented by a vector of dimensionality 768. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Aside: Hugging Face `transformers` models also have a `pooler_output` value. For BERT, this corresponds to the output representation above the [CLS] token, which is often used as a summary representation for the entire sequence. However, __we cannot use `pooler_output` in the current context__, as `transformers` adds new randomized parameters on top of it, to facilitate fine-tuning. If we want the [CLS] representation, we need to use `reps.last_hidden_state[:, 0]`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, if we want access to the output representations from each layer of the model, we use `hidden_states`. This will be `None` unless we set `output_hidden_states=True` when using the `forward` method, as above. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "13"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(reps.hidden_states)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The length 13 corresponds to the initial embedding layer (layer 0) and the 12 layers of this BERT model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The final layer in `hidden_states` is identical to `last_hidden_state`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 10, 768])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reps.hidden_states[-1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.equal(reps.hidden_states[-1], reps.last_hidden_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The decontextualized approach\n",
    "\n",
    "As discussed above, Bommasani et al. (2020) define and explore two general strategies for obtaining static representations for words using a model like BERT. The simpler one involves processing individual words and, where they correspond to multiple tokens, pooling those token representations into a single vector using an operation like mean."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Basic example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To begin to see what this is like in practice, we'll use the method `vsm.hf_encode`, which maps texts to their ids, taking care to use `unk_token` for texts that can't otherwise be processed by the model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Where a word corresponds to just one token in the vocabulary, it will get mapped to a single id:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['puppy']"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bert_tokenizer.tokenize('puppy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[21566]])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vsm.hf_encode(\"puppy\", bert_tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we saw above, some words map to multiple tokens:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['s', '##nu', '##ffle', '##up', '##agu', '##s']"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bert_tokenizer.tokenize('snuffleupagus')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[  188, 14787, 13327,  4455, 28026,  1116]])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "subtok_ids = vsm.hf_encode(\"snuffleupagus\", bert_tokenizer)\n",
    "\n",
    "subtok_ids"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, the function `vsm.hf_represent` will map a batch of ids to their representations in a user-supplied model, at a specified layer in that model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 6, 768])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "subtok_reps = vsm.hf_represent(subtok_ids, bert_model, layer=-1)\n",
    "\n",
    "subtok_reps.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shape here: 1 example containing 6 (sub-word) tokens, each of dimension 768. With `layer=-1`, we obtain the final output representation from the entire model."
   ]
  },
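  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the `layer` argument concrete, here is a sketch of what selecting a layer might amount to, using a tiny randomly initialized BERT so that nothing needs to be downloaded. The helper `represent_sketch` and the small configuration values are assumptions for illustration; this is not the `vsm.hf_represent` implementation, and the random model's vectors are meaningless."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import BertConfig, BertModel\n",
    "\n",
    "# A tiny, randomly initialized model (illustrative sizes only):\n",
    "tiny_config = BertConfig(\n",
    "    vocab_size=100, hidden_size=32, num_hidden_layers=2,\n",
    "    num_attention_heads=2, intermediate_size=64)\n",
    "tiny_model = BertModel(tiny_config)\n",
    "tiny_model.eval()  # Turn off dropout so that outputs are deterministic.\n",
    "\n",
    "def represent_sketch(ids, model, layer=-1):\n",
    "    # Run the model without gradients and index into `hidden_states`,\n",
    "    # where index 0 is the embedding layer and -1 is the final layer.\n",
    "    with torch.no_grad():\n",
    "        out = model(ids, output_hidden_states=True)\n",
    "    return out.hidden_states[layer]\n",
    "\n",
    "toy_ids = torch.tensor([[5, 17, 42]])\n",
    "first_layer = represent_sketch(toy_ids, tiny_model, layer=1)\n",
    "final_layer = represent_sketch(toy_ids, tiny_model, layer=-1)\n",
    "\n",
    "first_layer.shape, final_layer.shape"
   ]
  },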
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The final step is to pool together the 6 tokens. Here, we can use a variety of operations; [Bommasani et al. 2020](https://www.aclweb.org/anthology/2020.acl-main.431) find that `mean` is the best overall:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 768])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "subtok_pooled = vsm.mean_pooling(subtok_reps)\n",
    "\n",
    "subtok_pooled.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function `vsm.mean_pooling` is simply `torch.mean` with `axis=1`. There are also predefined functions `vsm.max_pooling`, `vsm.min_pooling`, and `vsm.last_pooling` (representation for the final token)."
   ]
  },
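  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough guide to what pooling functions of this kind do, here is a minimal sketch applied to a toy tensor. The `*_sketch` names are hypothetical stand-ins written for this illustration; the actual `vsm` implementations may differ in detail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def mean_pooling_sketch(hidden_states):\n",
    "    # Average over the token dimension: (batch, tokens, dim) -> (batch, dim).\n",
    "    return torch.mean(hidden_states, dim=1)\n",
    "\n",
    "def max_pooling_sketch(hidden_states):\n",
    "    # With `dim` given, torch.max returns (values, indices); keep the values.\n",
    "    return torch.max(hidden_states, dim=1).values\n",
    "\n",
    "def min_pooling_sketch(hidden_states):\n",
    "    return torch.min(hidden_states, dim=1).values\n",
    "\n",
    "def last_pooling_sketch(hidden_states):\n",
    "    # Keep only the representation of the final token.\n",
    "    return hidden_states[:, -1]\n",
    "\n",
    "# Toy input: 1 example, 3 sub-word tokens, 2 dimensions.\n",
    "toy = torch.tensor([[[1.0, 4.0], [2.0, 0.0], [3.0, 2.0]]])\n",
    "\n",
    "toy_mean = mean_pooling_sketch(toy)\n",
    "toy_max = max_pooling_sketch(toy)\n",
    "toy_min = min_pooling_sketch(toy)\n",
    "toy_last = last_pooling_sketch(toy)\n",
    "\n",
    "toy_mean, toy_max, toy_min, toy_last"
   ]
  },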
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating a full VSM\n",
    "\n",
    "Now we want to scale the above process to a large vocabulary, so that we can create a full VSM. The function `vsm.create_subword_pooling_vsm` makes this easy. To start, we get the vocabulary from one of our count VSMs (all of which have the same vocabulary):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "vsm_index = pd.read_csv(\n",
    "    os.path.join(DATA_HOME, 'yelp_window5-scaled.csv.gz'),\n",
    "    usecols=[0], index_col=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab = list(vsm_index.index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['):', ');', '..', '...', ':(']"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab[: 5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And then we use `vsm.create_subword_pooling_vsm`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 13min 30s, sys: 3min 3s, total: 16min 33s\n",
      "Wall time: 3min 37s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "pooled_df = vsm.create_subword_pooling_vsm(\n",
    "    vocab, bert_tokenizer, bert_model, layer=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The result, `pooled_df`, is a `pd.DataFrame` with its index given by `vocab`. This can be used directly in the word relatedness evaluations that are central to the homework and associated bakeoff."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(6000, 768)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pooled_df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>):</th>\n",
       "      <td>-0.125594</td>\n",
       "      <td>0.592184</td>\n",
       "      <td>0.248018</td>\n",
       "      <td>0.195743</td>\n",
       "      <td>0.284454</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>);</th>\n",
       "      <td>-0.396890</td>\n",
       "      <td>-0.001702</td>\n",
       "      <td>0.027441</td>\n",
       "      <td>0.144946</td>\n",
       "      <td>0.943918</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>..</th>\n",
       "      <td>-0.247551</td>\n",
       "      <td>0.435480</td>\n",
       "      <td>0.195532</td>\n",
       "      <td>0.283153</td>\n",
       "      <td>-0.496592</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>-0.238037</td>\n",
       "      <td>0.506108</td>\n",
       "      <td>0.245705</td>\n",
       "      <td>0.388796</td>\n",
       "      <td>-0.772245</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>:(</th>\n",
       "      <td>-0.710794</td>\n",
       "      <td>0.723624</td>\n",
       "      <td>0.444736</td>\n",
       "      <td>0.670010</td>\n",
       "      <td>0.604062</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            0         1         2         3         4\n",
       "):  -0.125594  0.592184  0.248018  0.195743  0.284454\n",
       ");  -0.396890 -0.001702  0.027441  0.144946  0.943918\n",
       "..  -0.247551  0.435480  0.195532  0.283153 -0.496592\n",
       "... -0.238037  0.506108  0.245705  0.388796 -0.772245\n",
       ":(  -0.710794  0.723624  0.444736  0.670010  0.604062"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pooled_df.iloc[: 5, :5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This approach, and the associated code, should work generally for all Hugging Face Transformer-based models. Bommasani et al. (2020) provide a lot of guidance when it comes to how the model, the layer choice, and the pooling function interact."
   ]
  },
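  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the connection to relatedness evaluation concrete, here is a small sketch of comparing two rows of a DataFrame shaped like `pooled_df` by cosine distance. The toy DataFrame and the helper `cosine_distance_sketch` are illustrative assumptions; for the actual assignment, use the distance functions and evaluation code that come with the course."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "def cosine_distance_sketch(u, v):\n",
    "    # 1 - cosine similarity; smaller values mean more similar directions.\n",
    "    u = np.asarray(u, dtype=float)\n",
    "    v = np.asarray(v, dtype=float)\n",
    "    return 1.0 - (u @ v) / (np.linalg.norm(u) * np.linalg.norm(v))\n",
    "\n",
    "# Toy stand-in for a VSM: words as the index, vector dimensions as columns.\n",
    "toy_df = pd.DataFrame(\n",
    "    [[1.0, 0.0], [2.0, 0.0], [0.0, 3.0]],\n",
    "    index=['cat', 'kitten', 'car'])\n",
    "\n",
    "dist_close = cosine_distance_sketch(toy_df.loc['cat'], toy_df.loc['kitten'])\n",
    "dist_far = cosine_distance_sketch(toy_df.loc['cat'], toy_df.loc['car'])\n",
    "\n",
    "dist_close, dist_far"
   ]
  },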
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The aggregated approach"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The aggregated approach is also straightforward to implement given the above tools. To start, we can create a map from vocabulary items into their sequences of ids:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab_ids = {w: vsm.hf_encode(w, bert_tokenizer)[0] for w in vocab}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's assume we have a corpus of texts that contain the words of interest:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus = [\n",
    "    \"This is a sailing example\",\n",
    "    \"It's fun to go sailing!\",\n",
    "    \"We should go sailing.\",\n",
    "    \"I'd like to go sailing and sailing\",\n",
    "    \"This is merely an example\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following embeds every corpus example, keeping `layer=1` representations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus_ids = [vsm.hf_encode(text, bert_tokenizer)\n",
    "              for text in corpus]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus_reps = [vsm.hf_represent(ids, bert_model, layer=1)\n",
    "               for ids in corpus_ids]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we define a convenience function for finding all the occurrences of a sublist in a larger list:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_sublist_indices(sublist, mainlist):\n",
    "    indices = []\n",
    "    length = len(sublist)\n",
    "    for i in range(0, len(mainlist)-length+1):\n",
    "        if mainlist[i: i+length] == sublist:\n",
    "            indices.append((i, i+length))\n",
    "    return indices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(0, 2), (4, 6)]"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "find_sublist_indices([1,2], [1, 2, 3, 0, 1, 2, 3])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And here's an example using our `vocab_ids` and `corpus`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "sailing = vocab_ids['sailing']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "sailing_reps = []\n",
    "\n",
    "for ids, reps in zip(corpus_ids, corpus_reps):\n",
    "    # Compare plain Python lists so that multi-token words are matched correctly:\n",
    "    offsets = find_sublist_indices(sailing.tolist(), ids.squeeze(0).tolist())\n",
    "    for (start, end) in offsets:\n",
    "        pooled = vsm.mean_pooling(reps[:, start: end])\n",
    "        sailing_reps.append(pooled)\n",
    "\n",
    "sailing_rep = torch.mean(torch.cat(sailing_reps), axis=0).squeeze(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([768])"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sailing_rep.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The above building blocks could be used as the basis for an original system and bakeoff entry for this unit. The major question is probably which data to use for the corpus."
   ]
  },
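  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As one possible starting point, the steps above can be wrapped into a single function that works for any word. This is a sketch under stated assumptions: `find_spans_sketch` and `aggregate_word_sketch` are hypothetical helpers, the toy ids and 2-dimensional vectors are made up, and mean pooling is used at both stages, which is only one of the options discussed above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def find_spans_sketch(sub_ids, seq_ids):\n",
    "    # All (start, end) offsets at which the list `sub_ids` occurs in `seq_ids`.\n",
    "    n = len(sub_ids)\n",
    "    return [(i, i + n) for i in range(len(seq_ids) - n + 1)\n",
    "            if seq_ids[i: i + n] == sub_ids]\n",
    "\n",
    "def aggregate_word_sketch(word_ids, corpus_ids, corpus_reps):\n",
    "    # word_ids: list of int sub-word ids for the target word.\n",
    "    # corpus_ids: list of tensors, each of shape (1, seq_len).\n",
    "    # corpus_reps: list of tensors, each of shape (1, seq_len, dim).\n",
    "    pooled = []\n",
    "    for ids, reps in zip(corpus_ids, corpus_reps):\n",
    "        for start, end in find_spans_sketch(word_ids, ids.squeeze(0).tolist()):\n",
    "            # Pool the sub-word tokens of this occurrence.\n",
    "            pooled.append(reps[:, start: end].mean(dim=1))\n",
    "    if not pooled:\n",
    "        return None  # The word never occurs in the corpus.\n",
    "    # Pool across all occurrences.\n",
    "    return torch.cat(pooled).mean(dim=0)\n",
    "\n",
    "# Toy data: the ids are arbitrary and each representation is 2-dimensional.\n",
    "toy_ids = [torch.tensor([[5, 7, 8, 9]]), torch.tensor([[7, 8, 3]])]\n",
    "toy_reps = [\n",
    "    torch.tensor([[[0., 0.], [1., 1.], [3., 3.], [0., 0.]]]),\n",
    "    torch.tensor([[[5., 5.], [7., 7.], [0., 0.]]])]\n",
    "\n",
    "toy_vec = aggregate_word_sketch([7, 8], toy_ids, toy_reps)\n",
    "toy_missing = aggregate_word_sketch([42], toy_ids, toy_reps)\n",
    "\n",
    "toy_vec"
   ]
  },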
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Some related work\n",
    "\n",
    "1. [Ethayarajh (2019)](https://www.aclweb.org/anthology/D19-1006/) uses dimensionality reduction techniques (akin to LSA) to derive static representations from contextual models, and explores layer-wise variation in detail, with findings that are likely to align with your experiences using the above techniques.\n",
    "\n",
    "1. [Akbik et al. (2019)](https://www.aclweb.org/anthology/N19-1078/) explore techniques similar to those of Bommasani et al. specifically for the supervised task of named entity recognition.\n",
    "\n",
    "1. [Wang et al. (2020)](https://arxiv.org/pdf/1911.02929.pdf) learn static representations from contextual ones using techniques adapted from the word2vec model."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
